前情提要: 昨天基本上已經把Demucs model的部分完成的70%,基本上只一點而已。
skip connections參考:
https://medium.com/@chkim345/addition-based-skip-connections-vs-concatenation-base-skip-connections-8e25de804fcb
https://www.zhihu.com/question/306213462
normalization參考: 我自己比較常看到的就是Z-Score Normalization
https://medium.com/@weidagang/demystifying-machine-learning-normalization-0cdb8b281234
這裡實作的是用add而不是concat,在類似這個model架構下我測試過concat和add,基本上效果差不多,但在其他應用場景可能還是要參考論文。
# 在Encoder的部分,用一個list存放每一層做完的結果
def forward(self, x):
skip = []
length = x.size(-1)
x = F.pad(x, (0, self.valid_length(length) - length))
print(x.size())
for idx, enc in enumerate(self.encoder):
x = enc(x)
skip.append(x)
print(f'idx: {idx}, x: {x.size()}')
skip = skip[::-1] # 這裡跟當初創Decoder一樣,直接將list做翻轉
return x, skip
# 在Decoder當中,將每一層的輸出讀出來,跟原先的x做add
def forward(self, x, skip):
for idx, dec in enumerate(self.decoder):
print(f'x: {x.size()}, skip: {skip[idx].size()}')
x = x + skip[idx]
x = dec(x)
print(f'idx: {idx}, x: {x.size()}')
return x
# 在Demucs當中,添加normalize,以及傳入skip的部分
def forward(self, x):
# 做 normalize
if self.normalize:
mean = x.mean(dim=(1, 2), keepdim=True)
std = x.std(dim=(1, 2), keepdim=True)
x = (x - mean) / (1e-5 + std)
length = x.size(-1)
x, skip = self.encoder(x)
x = x.permute(0, 2, 1) # 用自己的電腦跑rearrange會稍微卡住
# x = rearrange(x, 'b c l -> b l c') # permute比較快
x = self.bottleneck(x)
x = x.permute(0, 2, 1)
# x = rearrange(x, 'b l c -> b c l')
x = self.decoder(x, skip)
x = x[..., :length] # 回傳最一開始audio的長度
return std * x + mean
今天太忙加上沒有事先寫好,所以趕快來補。
今天就更新到這囉~~